{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine tuning and creating TRT Model\n",
    "In this notebook we show how to include TensorRT (TRT) 3.0 in a typical deep learning development workflow. The aim is to show you how you can take advantage of TensorRT to dramatically speed up your inference in a simple and straightforward manner.\n",
    "\n",
    "In this example we will see how to fine tune a VGG19 architecture trained on Imagenet to categorize different kinds of flower in 5 classes. After fine tuning we will test the accuracy of the model and save it in a format that is understandeable by TensorRT.\n",
    "\n",
    "## Workflow\n",
    "\n",
    "In this notebook we explore the \"training\" aspect of this problem. For this reason we will need to have tensorflow and keras packages installed in addition to TensorRT 3.0 with python interface, UFF and other modules. Referring to the figure that has been shown at the beginning of the webinar, we will be tackling the \"training the model\" portion of the slide.\n",
    "\n",
    "\n",
    "### Imports\n",
    "In the following we import the necessary python packages for this part of the hands-on session\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Import Keras Modules '''\n",
    "from keras.applications.vgg19 import VGG19\n",
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "from keras.models import Model, load_model\n",
    "from keras.layers import Dense, GlobalAveragePooling2D\n",
    "from keras import backend as K\n",
    "import keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Import Tensorflow Modules '''\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.framework import graph_io\n",
    "from tensorflow.python.tools import freeze_graph\n",
    "from tensorflow.core.protobuf import saver_pb2\n",
    "from tensorflow.python.training import saver as saver_lib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the two cells above we have imported the python pacakges needed to perform training (fine tuning) In the first cell we have imported Keras with tensorflow backhand. In the second cell we have imported tensorflow and in particular all the routines necessary to save a frozen version of the model. \n",
    "\n",
    "### Training configuration\n",
    "In the following we specify configuration parameters that are needed for training "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    # Training params\n",
    "    \"train_data_dir\": \"/home/data/train\",  # training data\n",
    "    \"val_data_dir\": \"/home/data/val\",  # validation data \n",
    "    \"train_batch_size\": 16,  # training batch size\n",
    "    \"epochs\": 3,  # number of training epochs\n",
    "    \"num_train_samples\" : 2936,  # number of training examples\n",
    "    \"num_val_samples\" : 734,  # number of test examples\n",
    "\n",
    "    # Where to save models (Tensorflow + TensorRT)\n",
    "    \"graphdef_file\": \"/home/model/keras_vgg19_graphdef.pb\",\n",
    "    \"frozen_model_file\": \"/home/model/keras_vgg19_frozen_model.pb\",\n",
    "    \"snapshot_dir\": \"/home/data/model/snapshot\",\n",
    "    \"engine_save_dir\": \"/home/model/\",\n",
    "    \n",
    "    # Needed for TensorRT\n",
    "    \"image_dim\": 224,  # the image size (square images)\n",
    "    \"inference_batch_size\": 1,  # inference batch size\n",
    "    \"input_layer\": \"input_1\",  # name of the input tensor in the TF computational graph\n",
    "    \"out_layer\": \"dense_2/Softmax\",  # name of the output tensorf in the TF conputational graph\n",
    "    \"output_size\" : 5,  # number of classes in output (5)\n",
    "    \"precision\": \"fp32\",  # desired precision (fp32, fp16)\n",
    "\n",
    "    \"test_image_path\" : \"/home/data/val/roses\"\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-tuning\n",
    "In the following we show how to load the VGG19 model and finetune it with images from the flower dataset. Finally, after fine tuning, we will save the model as a frozen tensorflow graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def finetune_and_freeze_model():\n",
    "    # create the base pre-trained model\n",
    "    base_model = VGG19(weights='imagenet', include_top=False, pooling='avg')\n",
    "\n",
    "    # add a global spatial average pooling layer\n",
    "    x = base_model.output\n",
    "    # let's add a fully-connected layer\n",
    "    x = Dense(1024, activation='relu')(x)\n",
    "    # and a softmax layer -- in this example we have 5 classes\n",
    "    predictions = Dense(5, activation='softmax')(x)\n",
    "\n",
    "    # this is the model we will finetune\n",
    "    model = Model(inputs=base_model.input, outputs=predictions)\n",
    "\n",
    "    # We want to use the convolutional layers from the pretrained \n",
    "    # VGG19 as feature extractors, so we freeze those layers and exclude \n",
    "    # them from training and train only the new top layers\n",
    "    for layer in base_model.layers:\n",
    "        print(layer.get_config())\n",
    "        layer.trainable = False\n",
    "\n",
    "    # compile the model (should be done *after* setting layers to non-trainable)\n",
    "    model.compile(\n",
    "        optimizer=keras.optimizers.Adam(1e-4), \n",
    "        loss='categorical_crossentropy', \n",
    "        metrics = ['accuracy']\n",
    "    )\n",
    "    \n",
    "    #create data generators for training/validation\n",
    "    train_datagen = ImageDataGenerator()\n",
    "    val_datagen = ImageDataGenerator()\n",
    "\n",
    "    train_generator = train_datagen.flow_from_directory(\n",
    "        directory=config['train_data_dir'],\n",
    "        target_size=(config['image_dim'], config['image_dim']), \n",
    "        batch_size=config['train_batch_size']\n",
    "    )\n",
    "\n",
    "    val_generator = val_datagen.flow_from_directory(\n",
    "        directory=config['val_data_dir'], \n",
    "        target_size=(config['image_dim'], config['image_dim']), \n",
    "        batch_size=config['train_batch_size']\n",
    "    )\n",
    "\n",
    "    # train the model on the new data for a few epochs\n",
    "    model.fit_generator(\n",
    "        train_generator, \n",
    "        steps_per_epoch=config['num_train_samples']//config['train_batch_size'], \n",
    "        epochs=config['epochs'], \n",
    "        validation_data=val_generator, \n",
    "        validation_steps=config['num_val_samples']//config['train_batch_size']\n",
    "    )\n",
    "\n",
    "    # Now, let's use the Tensorflow backend to get the TF graphdef and frozen graph\n",
    "    K.set_learning_phase(0)\n",
    "    sess = K.get_session()\n",
    "    saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)\n",
    "\n",
    "    # save model weights in TF checkpoint\n",
    "    checkpoint_path = saver.save(sess, config['snapshot_dir'], global_step=0, latest_filename='checkpoint_state')\n",
    "\n",
    "    # remove nodes not needed for inference from graph def\n",
    "    train_graph = sess.graph\n",
    "    inference_graph = tf.graph_util.remove_training_nodes(train_graph.as_graph_def())\n",
    "\n",
    "    # write the graph definition to a file. \n",
    "    # You can view this file to see your network structure and \n",
    "    # to determine the names of your network's input/output layers.\n",
    "    graph_io.write_graph(inference_graph, '.', config['graphdef_file'])\n",
    "\n",
    "    # specify which layer is the output layer for your graph. \n",
    "    # In this case, we want to specify the softmax layer after our\n",
    "    # last dense (fully connected) layer. \n",
    "    out_names = config['out_layer']\n",
    "\n",
    "    # freeze your inference graph and save it for later! (Tensorflow)\n",
    "    freeze_graph.freeze_graph(\n",
    "        config['graphdef_file'], \n",
    "        '', \n",
    "        False, \n",
    "        checkpoint_path, \n",
    "        out_names, \n",
    "        \"save/restore_all\", \n",
    "        \"save/Const:0\", \n",
    "        config['frozen_model_file'], \n",
    "        False, \n",
    "        \"\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run fine-tuning\n",
    "Fine-tuning will run for the specified number of epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "finetune_and_freeze_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
